# Libraries
library(tidyverse)
library(here)
library(keras)
library(caret)
library(UBL)
library(gmRi)

# Windows path setup (kept for reference)
#mills_path <- shared.path(os.use = "windows", group = "Mills Lab", folder = NULL)

#AK setup
# Shared-drive root for Mills Lab data (unix mount) and the conda environment
# providing the Python backend used by keras via reticulate
mills_path <- shared.path(os.use = "unix", group = "Mills Lab", folder = NULL)
reticulate::use_condaenv("rkeras2020")

# default ggplot theme for all figures below
theme_set(theme_bw())

# Filter and preprocess data ----

# NEFSC bottom trawl survey: loads `survdat` into the workspace
load(str_c(mills_path, "Data/Survdat_Nye_allseason.RData"))

# render tow IDs as plain (non-scientific) strings so joins on ID are stable
survdat$ID <- format(survdat$ID, scientific = FALSE)

# restrict to the study years (1982-2015) and the two main survey seasons
dat <- survdat %>%
  filter(EST_YEAR %in% 1982:2015) %>%
  filter(SEASON %in% c("SPRING", "FALL"))

# offshore strata to include are (starts with 1 ends with 0, so offshore strata number 13 is 1130)
# Georges Bank: 1130 1140 1150 1160 1170 1180 1190 1200 1210 1220 1230 1240 1250
# Gulf of Maine: 1260 1270 1280 1290 1300 1360 1370 1380 1390 1400
# inshore strata to include are (starts with 3 ends with 0, so inshore strata 61 is 3610)
# Georges Bank: 3560
# Gulf of Maine: 3580 3590 3600 3610 3630 3640 3650 3660
strata <- c(seq(1130, 1300, by = 10), seq(1360, 1400, by = 10),
            3560, seq(3580, 3610, by = 10), seq(3630, 3660, by = 10))
dat <- dat %>% filter(STRATUM %in% strata)

# for the years, seasons, and strata listed, keep only Atlantic cod records
cod <- dat %>% filter(COMNAME == "ATLANTIC COD")

# collapse to one row per tow ID (all species)
dat <- dat %>% distinct(ID, .keep_all = TRUE)

# one row per tow ID for tows that caught cod
codtow <- cod %>% distinct(ID, .keep_all = TRUE)

# drop records missing length or abundance information
cod <- cod %>% filter(!is.na(NUMLEN), !is.na(ABUNDANCE), !is.na(LENGTH))
codtow <- codtow %>% filter(!is.na(NUMLEN), !is.na(ABUNDANCE), !is.na(LENGTH))

# expand to one row per measured fish (NUMLEN copies of each length record)
cod <- uncount(cod, NUMLEN)

# Three size classes ----

# bin lengths to 3 groups: small <20, medium 20-60, large >60
# (NUM = 1 so that summing NUM per tow counts fish)
cod <- cod %>%
  mutate(
    SIZE = ifelse(LENGTH < 20, "small", ifelse(LENGTH > 60, "large", "medium")),
    NUM = 1
  )

# for each tow ID, count the number of small, medium, and large fish
sizecounts <- cod %>%
  pivot_wider(id_cols = ID, names_from = SIZE, values_from = NUM,
              values_fn = list(NUM = sum))

# attach the per-tow size counts and drop the fish-level columns
codtow <- left_join(codtow, sizecounts, by = "ID")
codtow <- select(codtow, -c(LENGTH, NUMLEN))

# join tows with cod with tows without cod
# (all tows from `dat` on the left; cod-tow columns get the .y suffix)
x<-left_join(dat,codtow,by="ID")

# select columns to keep
# NOTE(review): position-based selection -- assumes a fixed column layout of
# the joined table; verify these indices if upstream columns change
x<-select(x,c(1:29,35:43,74,75,76,77,78,88,89,90))

# if COMNAME is NA, fill-in biomass,abundance,small,medium,large with 0
# (rows with no matching cod record are true zero-catch tows)
notcod<-which(is.na(x$COMNAME.y))
x$BIOMASS.y[notcod]<-0
x$ABUNDANCE.y[notcod]<-0
x$small[notcod]<-0
x$medium[notcod]<-0
x$large[notcod]<-0

# for tows with cod, fill-in size category abundance NA with 0
# (a size class absent from a tow produced NA in the pivot_wider step)
x$small[is.na(x$small)]<-0
x$medium[is.na(x$medium)]<-0
x$large[is.na(x$large)]<-0

# proportion of relative abundance allocated to each size class
# vectorized (the original looped row by row): each tow's ABUNDANCE is split
# in proportion to its small/medium/large counts. Tows with zero counts give
# 0/0 = NaN, which is reset to 0 below -- same as the original loop.
sizetot<-x$small+x$medium+x$large
x$nsmall<-x$ABUNDANCE.y*x$small/sizetot
x$nmedium<-x$ABUNDANCE.y*x$medium/sizetot
x$nlarge<-x$ABUNDANCE.y*x$large/sizetot
x$nsmall[is.na(x$nsmall)]<-0
x$nmedium[is.na(x$nmedium)]<-0
x$nlarge[is.na(x$nlarge)]<-0

# proportion of relative biomass allocated to each size class
# (same size-class proportions applied to BIOMASS)
x$bsmall<-x$BIOMASS.y*x$small/sizetot
x$bmedium<-x$BIOMASS.y*x$medium/sizetot
x$blarge<-x$BIOMASS.y*x$large/sizetot
x$bsmall[is.na(x$bsmall)]<-0
x$bmedium[is.na(x$bmedium)]<-0
x$blarge[is.na(x$blarge)]<-0

# Area of strata ----

# bring in strata area
strataarea<-read_csv(str_c(mills_path, "Projects/NSF_CAccel/Data/strata area.csv"))
strataarea<-select(strataarea,c(area,stratum))
# NOTE(review): column 25 is assumed to be the STRATUM column -- confirm
# against the selected layout above before changing upstream selections
colnames(x)[25]<-"stratum"
x<-left_join(x,strataarea,by="stratum")

# column names without .x or .y at the end
# (keeps only the text before the first "." in each column name)
colnames(x)<-str_split(string=colnames(x),pattern="[.]",simplify=TRUE)[,1]

# Assign strata to area
# Georges Bank: 1130 1140 1150 1160 1170 1180 1190 1200 1210 1220 1230 1240 1250 3560
# Gulf of Maine: 1260 1270 1280 1290 1300 1360 1370 1380 1390 1400 3580 3590 3600 3610 3630 3640 3650 3660
# specify prior area columns (statistical area and stratum size area)
# NOTE(review): positions 30 and 53 are assumed fixed -- verify if layout changes
colnames(x)[c(30,53)]<-c("STATAREA","STRATUMAREA")
GBstrata<-c(1130,1140,1150,1160,1170,1180,1190,1200,1210,1220,1230,1240,1250,3560)
x$AREA<-ifelse(x$stratum%in%GBstrata,"GB","GoM")

# calculate annual area/season/size class mean abundance and biomass within strata
# area: GoM, GB
# season: spring, fall
# size class: small, medium, large
q <- x %>%
  group_by(EST_YEAR,AREA,SEASON,stratum) %>%
  summarise(mnsmall=mean(nsmall),mnmedium=mean(nmedium),mnlarge=mean(nlarge),
            mnbsmall=mean(bsmall),mnbmedium=mean(bmedium),mnblarge=mean(blarge),
            STRATUMAREA=mean(STRATUMAREA))

# calculate stratum weights by stratum area
# (each stratum's share of the total area within a year/area/season)
q <- q %>%
  group_by(EST_YEAR,AREA,SEASON) %>%
  mutate(weight=STRATUMAREA/(sum(STRATUMAREA)))

# calculate annual area/season/size class mean abundance and biomass across strata
# (area-weighted stratified means)
p <- q %>%
  group_by(EST_YEAR,AREA,SEASON) %>%
  summarise(abundance_small=weighted.mean(mnsmall,weight),
            abundance_medium=weighted.mean(mnmedium,weight),
            abundance_large=weighted.mean(mnlarge,weight),
            biomass_small=weighted.mean(mnbsmall,weight),
            biomass_medium=weighted.mean(mnbmedium,weight),
            biomass_large=weighted.mean(mnblarge,weight))

# long format for plots
# FIX: the argument was misspelled `names_patter`, which only worked through
# partial argument matching and is silently dropped into `...` by newer tidyr
a <- p %>%
  pivot_longer(
    cols=4:9,
    names_to=c("type","size"),
    names_pattern="(.*)_(.*)",
    values_to="value"
  )

# Stratified abundance ----

# plot abundance
# stratified mean abundance time series, one panel per area/season
# combination, lines colored by size class
a%>%
  filter(type=="abundance")%>%
  ggplot(aes(x=EST_YEAR,y=value,group=size,color=size))+
  geom_line(size=1)+
  facet_grid(AREA+SEASON~.,scales="free")+
  labs(title="stratified mean abundance",x="year",y="mean abundance per tow across strata")

# Survey data for year i, i-1, i-2, and i-3 ----

# data frame to store stratified survey year i AND year i-1, i-2, and i-3
# column 1 is year
# other columns will be ordered by area (GoM, GB), season (spring, fall), type (abundance, biomass), size classes s/m/l
# 1 + (2 x 2 x 2 x 3) = 25
# (number of years derived from the data rather than hard-coded 34)
yrs<-sort(unique(p$EST_YEAR))
nyr<-length(yrs)
survdf<-data.frame(matrix(NA,nrow=nyr,ncol=25))
colnames(survdf)<-c("year","gom_spr_abun_sml","gom_spr_abun_med","gom_spr_abun_lrg",
                           "gom_spr_bio_sml","gom_spr_bio_med","gom_spr_bio_lrg",
                           "gom_fal_abun_sml","gom_fal_abun_med","gom_fal_abun_lrg",
                           "gom_fal_bio_sml","gom_fal_bio_med","gom_fal_bio_lrg",
                           "gb_spr_abun_sml","gb_spr_abun_med","gb_spr_abun_lrg",
                           "gb_spr_bio_sml","gb_spr_bio_med","gb_spr_bio_lrg",
                           "gb_fal_abun_sml","gb_fal_abun_med","gb_fal_abun_lrg",
                           "gb_fal_bio_sml","gb_fal_bio_med","gb_fal_bio_lrg")

# add year
survdf[,1]<-yrs

# rows of p for one year sort as (GB,FALL),(GB,SPRING),(GoM,FALL),(GoM,SPRING),
# so taking rows 4:1 yields GoM-spring, GoM-fall, GB-spring, GB-fall --
# matching the column-name order defined above
pdf<-as.data.frame(p)
for(i in seq_len(nyr)){
  yr_rows<-subset(pdf,pdf$EST_YEAR==survdf$year[i])
  survdf[i,2:25]<-as.numeric(unlist(c(yr_rows[4,4:9],
                                      yr_rows[3,4:9],
                                      yr_rows[2,4:9],
                                      yr_rows[1,4:9])))
}

# now make columns for year i-1, i-2, and i-3
# catch data limited to earliest year 1982 so first year with all three is 1985
# NOTE: survdf minus the year column has 24 columns, so each NaN padding row
# must have 24 values (the original used rep(NaN,25), which does not match);
# row counts are derived from survdf instead of hard-coding 33/32/31
n<-nrow(survdf)

# year i-1: shift values down one row, padding the first row with NaN
yi1<-rbind(rep(NaN,24),survdf[1:(n-1),-1])
colnames(yi1)<-paste(colnames(yi1),"i1",sep="_")

# year i-2
yi2<-rbind(rep(NaN,24),rep(NaN,24),survdf[1:(n-2),-1])
colnames(yi2)<-paste(colnames(yi2),"i2",sep="_")

# year i-3
yi3<-rbind(rep(NaN,24),rep(NaN,24),rep(NaN,24),survdf[1:(n-3),-1])
colnames(yi3)<-paste(colnames(yi3),"i3",sep="_")

# bind columns together
survdf<-bind_cols(survdf,yi1,yi2,yi3)

# reorder columns: all abundance columns first, then all biomass columns
survdf<-survdf[,c(str_which(colnames(survdf),pattern="bio",negate=TRUE),
                  str_which(colnames(survdf),pattern="bio",negate=FALSE))]

# join with x dataframe by EST_YEAR and year
# (attaches the stratified survey time series to every tow in that year)
towsurv<-left_join(x,survdf,by=c("EST_YEAR"="year"))

# columns for tow information
moddf<-select(towsurv,c(ID,EST_YEAR,SEASON,SVVESSEL,TOWDUR,AVGDEPTH,stratum,nsmall,nmedium,nlarge,AREA))

# stratified mean values
# NOTE(review): position-based selection (55:102) assumes the survdf columns
# land at these positions after the join -- verify if the layout changes
stratmnvals<-select(towsurv,c(55:102))

# SST at trawl locations at year i and i-1, and regional SST at year i-1 ----

###############################################################################################
# SST at each trawl ID location for year i
trawltemp <- read_csv(
  str_c(mills_path, "Projects/NSF_CAccel/Data/TrawlTemperatures2.csv"),
  col_names = c("ID", "tempK", "tempK10", "anom", "anom10")
)
trawltemp$ID <- format(trawltemp$ID, scientific = FALSE)

# convert Kelvin to Celsius
trawltemp <- trawltemp %>%
  mutate(tempC = tempK - 273.15, tempC10 = tempK10 - 273.15)

# join onto the model data by tow ID, keeping only ID plus the anomaly columns
widedat <- left_join(moddf, trawltemp[, -c(2, 3, 6, 7)], by = "ID")

# trawl SST for year i-1 and regional SST for year i-1
trawltempprev <- read_csv(str_c(mills_path, "Projects/NSF_CAccel/Data/TrawlTemperatures_Previous2.csv"))
trawltempprev$ID <- format(trawltempprev$ID, scientific = FALSE)

# keep rows with no missing values
trawltempprev <- trawltempprev[complete.cases(trawltempprev), ]

# convert the monthly temperature columns from Kelvin to Celsius
trawltempprev[, c(3:14, 16:27)] <- trawltempprev[, c(3:14, 16:27)] - 273.15

# One-hot encode ----

###############################################################################################
# change factors to characters so the equality comparisons below are on strings
widedat$SVVESSEL<-as.character(widedat$SVVESSEL)
widedat$SEASON<-as.character(widedat$SEASON)

# one-hot encode season, vessel, stratum, area
# 2 areas + 2 seasons + 3 vessels + 32 strata = 39 indicator columns
onehot<-data.frame(matrix(NaN,nrow=nrow(widedat),ncol=39))
colnames(onehot)<-paste("is",c(rev(unique(widedat$AREA)),
                               unique(widedat$SEASON),
                               unique(widedat$SVVESSEL),
                               sort(unique(widedat$stratum))),sep="")

# AREA
onehot$isGoM<-ifelse(widedat$AREA=="GoM",1,0)
onehot$isGB<-ifelse(widedat$AREA=="GB",1,0)

# SEASON
onehot$isSPRING<-ifelse(widedat$SEASON=="SPRING",1,0)
onehot$isFALL<-ifelse(widedat$SEASON=="FALL",1,0)

# VESSEL
onehot$isDE<-ifelse(widedat$SVVESSEL=="DE",1,0)
onehot$isAL<-ifelse(widedat$SVVESSEL=="AL",1,0)
onehot$isHB<-ifelse(widedat$SVVESSEL=="HB",1,0)

# STRATUM
# stratum indicator columns start at position 8 (offset 7 after the 7 columns above)
# FIX: the original hard-coded the row count (1:9157); use the actual row count
strata<-sort(unique(widedat$stratum))
for(i in seq_len(nrow(widedat))){onehot[i,which(strata==widedat$stratum[i])+7]<-1}
onehot[is.na(onehot)]<-0

# bind together
widedat<-bind_cols(widedat,onehot)

# last bit of reordering columns
# NOTE(review): position-based reordering -- verify if widedat layout changes
widedat<-widedat[,c(1,11,2:4,7,8,9,10,5,6,12:52)]

# Catch data ----

###############################################################################################
# catch data for year i-1
gomcatch <- read_csv(str_c(mills_path, "Projects/NSF_CAccel/Data/gom_catch_at_age_19.csv")) # ages 1-9+ years 1982-2018
gbcatch <- read_csv(str_c(mills_path, "Projects/NSF_CAccel/Data/gb_catch_at_age_19.csv")) # ages 1-10+ years 1978-2014
# age overlap is ages 1-9+; year overlap is years 1982-2014

# remove + signs in column names
colnames(gomcatch)[10] <- "age_9plus"
colnames(gbcatch)[11] <- "age_10"

# for GB, fold age 10+ into a combined 9+ group and drop the originals
gbcatch <- gbcatch %>%
  mutate(age_9plus = age_9 + age_10) %>%
  select(-c(age_9, age_10))

# align the year ranges of the two regions to 1982-2014
gbcatch <- gbcatch %>% filter(year >= 1982)
gomcatch <- gomcatch %>% filter(year <= 2014)

# rename columns to specify GoM or GB
# catch = c, r1 = GoM, r2 = GB
colnames(gomcatch)[2:10] <- paste0("c_r1_", colnames(gomcatch)[2:10])
colnames(gbcatch)[2:10] <- paste0("c_r2_", colnames(gbcatch)[2:10])
catch <- left_join(gomcatch, gbcatch, by = "year")

# bump catch years up by one so year i pairs with catch from year i-1
catch$year <- catch$year + 1

# SST at region level year i ----

###############################################################################################
# Regional SST for year i
# each file holds: yearly anomaly, 12 monthly means, 12 monthly anomalies;
# the repeated year columns are named year2/year3 here and dropped below
SST_GB<-read_csv(str_c(mills_path, "Projects/NSF_CAccel/Data/SSTdata_GB.csv"),
                 col_names=c("year","yranom_gb",
                             "year2","m1_gb","m2_gb","m3_gb","m4_gb","m5_gb","m6_gb","m7_gb","m8_gb","m9_gb","m10_gb","m11_gb","m12_gb",
                             "year3","m1anom_gb","m2anom_gb","m3anom_gb","m4anom_gb","m5anom_gb","m6anom_gb",
                             "m7anom_gb","m8anom_gb","m9anom_gb","m10anom_gb","m11anom_gb","m12anom_gb"))

SST_GoM<-read_csv(str_c(mills_path, "Projects/NSF_CAccel/Data/SSTdata_GOM.csv"),
                  col_names=c("year","yranom_gom",
                              "year2","m1_gom","m2_gom","m3_gom","m4_gom","m5_gom","m6_gom","m7_gom","m8_gom","m9_gom","m10_gom","m11_gom","m12_gom",
                              "year3","m1anom_gom","m2anom_gom","m3anom_gom","m4anom_gom","m5anom_gom","m6anom_gom",
                              "m7anom_gom","m8anom_gom","m9anom_gom","m10anom_gom","m11anom_gom","m12anom_gom"))

# remove extra year columns
SST_GB<-select(SST_GB,-c(year2,year3))
SST_GoM<-select(SST_GoM,-c(year2,year3))

# join data frames together (one row per year, GoM and GB side by side)
SST<-left_join(SST_GoM,SST_GB,by="year")

# Example of over-sampling rare instances and under-sampling common instances ----

# Trawl observations with high abundance in the medium size class

# combine together
# bind with stratmnvals
exampledata<-bind_cols(widedat,stratmnvals)
# add catch data to survey data
exampledata<-left_join(exampledata,catch,by=c("EST_YEAR"="year"))
# join trawltempprev with exampledata by ID
exampledata<-left_join(exampledata,trawltempprev,by="ID")
# join SST
exampledata<-left_join(exampledata,SST,by=c("EST_YEAR"="year"))
# complete cases only
# (drops tows missing any predictor, e.g. years without full 3-year lags)
exampledata<-exampledata[complete.cases(exampledata),]

# order the dataset randomly before dividing data
# NOTE(review): no set.seed(), so fold membership differs between runs --
# confirm whether reproducibility is desired here
neworder<-sample(nrow(exampledata))
exampledata<-exampledata[neworder,]

# create 5 folds, each fold gets a turn being the test data
folds<-createFolds(y=1:nrow(exampledata),k=5)

# matrix for assigning folds
# row i gives fold roles for model i: cols 1-3 train, col 4 validation, col 5 test
mat<-rbind(c(1,2,3,4,5),
           c(2,3,4,5,1),
           c(3,4,5,1,2),
           c(4,5,1,2,3),
           c(5,1,2,3,4))

# 3 folds as training data
train_ind<-c(folds[[mat[1,1]]],folds[[mat[1,2]]],folds[[mat[1,3]]])
# 1 fold as validation
val_ind<-folds[[mat[1,4]]]
# 1 fold as testing
test_ind<-folds[[mat[1,5]]]

# training labels and features
# labels are columns 7:9 (nsmall, nmedium, nlarge); features are the rest
train_labels<-exampledata[train_ind,c(7,8,9)]
train_data<-exampledata[train_ind,c(10:ncol(exampledata))]

# Normalize features
# 1 is TOWDUR
# 2 is AVGDEPTH
# 44:91 is stratified mean abundance year i and year i-1,i-2,i-3
# 92:109 is catch
# 111:122 is R1 temp year i-1 (GoM)
# 124:135 is R2 temp year i-1 (GB)
# 137:148 is GoM temp year i
# 162:173 is GB temp year i
# NOTE(review): position-based selection -- verify if the feature layout changes
thesecols<-c(1,2,44:91,92:109,111:122,124:135,137:148,162:173)
thesenames<-colnames(train_data)[thesecols]

# validation and testing data is not used when calculating mean and std
# scale() both standardizes the training features and records the training
# mean/sd as attributes; the original selected and scaled the same columns
# twice (once for the stats, once again for the data) -- done once here
sccolstrain<-scale(select(train_data,thesecols))
col_means_trainval<-attr(sccolstrain,"scaled:center")
col_stddevs_trainval<-attr(sccolstrain,"scaled:scale")

# put back together: unscaled columns first, then the normalized columns
train_data<-select(train_data,-thesecols)
sccolstrain<-as.data.frame(sccolstrain)
colnames(sccolstrain)<-thesenames
train_data<-bind_cols(train_data,sccolstrain)

# oversample rare instances in training data (observations with high abundance in medium size class)
# bind train_labels with train_data
imbal<-bind_cols(train_labels,train_data)

# relevance function for over-sampling probability
# relevance of 0 up to 80th percentile, ramp-up to relevance of 1 by 90th percentile
begin_rel<-quantile(imbal$nmedium,0.8)
end_rel<-quantile(imbal$nmedium,0.9)

# 80th percentile
begin_rel
## 80% 
##   2
# 90th percentile
end_rel
##      90% 
## 7.383333
# control-point matrix for the relevance function (one row per control point;
# see the UBL documentation for the column meanings)
rel<-matrix(c(begin_rel,0,0,end_rel,1,0),ncol=3,byrow=TRUE)

# oversample using function from package UBL
# Importance Sampling algorithm for imbalanced regression problems

# This function handles imbalanced regression problems using the relevance function provided to
# re-sample the data set. The relevance function is used to introduce replicas of the most important
# examples and to remove the least important examples.

# O parameter expresses importance to over-sampling, default is 0.5, set to 5
# U parameter expresses importance to under-sampling, default is 0.5, set as default
bal<-ImpSampRegress(nmedium~.,imbal,rel=rel,O=5,U=0.5)

# numbers of instances increase from 
nrow(imbal)
## [1] 4658
# to...
nrow(bal)
## [1] 26189
# how many replicates were actually made?
# replicated rows get names like "123.4" (base row name, dot, copy number)
z<-strsplit(row.names(bal),"[.]")

# number of unique instances that were replicated
# vectorized: the original grew z1 with c() inside a loop
z1<-vapply(z,function(p) p[1],character(1))
length(unique(z1))
## [1] 3875
# range of number of replicates
# p[2] is NA for rows that were never duplicated
zz<-vapply(z,function(p) p[2],character(1))
boxplot(as.numeric(zz[!is.na(zz)]))

# plot difference before/after resampling
plot(imbal$nmedium,main="before resampling")

plot(bal$nmedium,main="after resampling")

# density plots before/after
imbal%>%ggplot(aes(nmedium))+geom_density()+labs(title="before resampling")

bal%>%ggplot(aes(nmedium))+geom_density()+labs(title="after resampling")

# how many common instances were removed?
sum(!(row.names(imbal)%in%unique(z1)))
## [1] 1656
# plot abundance for instances removed
q<-imbal[!(row.names(imbal)%in%unique(z1)),1:3]
plot(q$nsmall)

plot(q$nmedium)

plot(q$nlarge)

# sometimes nmedium is 0 but other size classes have large abundance values...

# Five models so that all data can be set as test data, each trained for 50 epochs ----

# combine together

# bind with stratmnvals
alldata<-bind_cols(widedat,stratmnvals)

# add catch data to survey data
alldata<-left_join(alldata,catch,by=c("EST_YEAR"="year"))

# join trawltempprev with alldata by ID
alldata<-left_join(alldata,trawltempprev,by="ID")

# join SST
alldata<-left_join(alldata,SST,by=c("EST_YEAR"="year"))

# complete cases only
# (drops tows missing any predictor, e.g. years without full 3-year lags)
alldata<-alldata[complete.cases(alldata),]

# order the dataset randomly before dividing data
# NOTE(review): no set.seed(), so fold membership differs between runs --
# confirm whether reproducibility is desired
neworder<-sample(nrow(alldata))
alldata<-alldata[neworder,]

# create 5 folds, each fold gets a turn being the test data
folds<-createFolds(y=1:nrow(alldata),k=5)

# matrix for assigning folds
# row i gives fold roles for model i: cols 1-3 train, col 4 validation, col 5 test
mat<-rbind(c(1,2,3,4,5),
           c(2,3,4,5,1),
           c(3,4,5,1,2),
           c(4,5,1,2,3),
           c(5,1,2,3,4))

# initialize lists to hold results
learningcurves<-list()
testevals<-list()
res<-list()

# for loop for each fold held out as test data, 5 models
# each iteration: rotate fold roles per `mat`, normalize features using
# training-fold statistics only, oversample rare high-nmedium rows, fit a
# two-hidden-layer dense network on log(x+1) labels, and store the learning
# curves, test evaluations, and back-transformed predictions
#i<-1
for(i in 1:5){
  
# 3 folds as training data
train_ind<-c(folds[[mat[i,1]]],folds[[mat[i,2]]],folds[[mat[i,3]]])
# 1 fold as validation
val_ind<-folds[[mat[i,4]]]
# 1 fold as testing
test_ind<-folds[[mat[i,5]]]

# training labels and features
# labels are columns 7:9 (nsmall, nmedium, nlarge); features are the rest
train_labels<-alldata[train_ind,c(7,8,9)]
train_data<-alldata[train_ind,c(10:ncol(alldata))]

# validation labels and features
val_labels<-alldata[val_ind,c(7,8,9)]
val_data<-alldata[val_ind,10:ncol(alldata)]

# testing labels and features
test_labels<-alldata[test_ind,c(7,8,9)]
test_data<-alldata[test_ind,10:ncol(alldata)]

# Normalize features
# 1 is TOWDUR
# 2 is AVGDEPTH
# 44:91 is stratified mean abundance year i and year i-1,i-2,i-3
# 92:109 is catch
# 111:122 is R1 temp year i-1 (GoM)
# 124:135 is R2 temp year i-1 (GB)
# 137:148 is GoM temp year i
# 162:173 is GB temp year i
# NOTE(review): position-based selection -- verify if the feature layout changes
thesecols<-c(1,2,44:91,92:109,111:122,124:135,137:148,162:173)
thesenames<-colnames(train_data)[thesecols]

# validation and testing data is not used when calculating mean and std
# split features to be normalized
sccols<-select(train_data,thesecols)
# calculate mean and std of training data
sccols<-scale(sccols)
# use mean and std from training data to normalize training, validation and testing data
col_means_trainval<-attr(sccols,"scaled:center")
col_stddevs_trainval<-attr(sccols,"scaled:scale")

# training data
# split features to be normalized
sccolstrain<-select(train_data,thesecols)
train_data<-select(train_data,-thesecols)
# normalize
sccolstrain<-scale(sccolstrain,center=col_means_trainval,scale=col_stddevs_trainval)
# put back together (unscaled columns first, then the normalized columns)
sccolstrain<-as.data.frame(sccolstrain)
colnames(sccolstrain)<-thesenames
train_data<-bind_cols(train_data,sccolstrain)

# validation data
# split features to be normalized
sccolsval<-select(val_data,thesecols)
val_data<-select(val_data,-thesecols)
# normalize
sccolsval<-scale(sccolsval,center=col_means_trainval,scale=col_stddevs_trainval)
# put back together
sccolsval<-as.data.frame(sccolsval)
colnames(sccolsval)<-thesenames
val_data<-bind_cols(val_data,sccolsval)

# testing data
# split features to be normalized
sccolstest<-select(test_data,thesecols)
test_data<-select(test_data,-thesecols)
# normalize
sccolstest<-scale(sccolstest,center=col_means_trainval,scale=col_stddevs_trainval)
# put back together
sccolstest<-as.data.frame(sccolstest)
colnames(sccolstest)<-thesenames
test_data<-bind_cols(test_data,sccolstest)

# oversample rare instances in training data (observations with high abundance in medium size class)
# bind train_labels with train_data
imbal<-bind_cols(train_labels,train_data)

# relevance function
# relevance 0 at the 80th percentile of nmedium ramping to 1 at the 90th
begin_rel<-quantile(imbal$nmedium,0.8)
end_rel<-quantile(imbal$nmedium,0.9)
rel<-matrix(c(begin_rel,0,0,end_rel,1,0),ncol=3,byrow=TRUE)

# oversample
bal<-ImpSampRegress(nmedium~.,imbal,rel=rel,O=5,U=0.5)

# split labels and features
train_labels<-bal[,c(1,2,3)]
train_data<-bal[,c(4:ncol(bal))]

# log transform labels (log(x+1) keeps zero catches finite)
train_labels<-log(train_labels+1)
val_labels<-log(val_labels+1)
test_labels<-log(test_labels+1)

# convert to matrix
# training
train_data<-as.matrix(train_data)
train_labels<-as.matrix(train_labels)
# validation
val_data<-as.matrix(val_data)
val_labels<-as.matrix(val_labels)
# testing
test_data<-as.matrix(test_data)
test_labels<-as.matrix(test_labels)

# remove all temperature features
#train_data<-train_data[,-c(1,2,42:69,138:185)]
#val_data<-val_data[,-c(1,2,42:69,138:185)]
#test_data<-test_data[,-c(1,2,42:69,138:185)]

# remove area, season, strata one-hot features
#train_data<-train_data[,-c(3:41)]
#val_data<-val_data[,-c(3:41)]
#test_data<-test_data[,-c(3:41)]

# remove anom, anom10, and depth
#train_data<-train_data[,-c(1,2,71)]
#val_data<-val_data[,-c(1,2,71)]
#test_data<-test_data[,-c(1,2,71)]

#######################################################################################

# keras model

# input layer
inputs<-layer_input(shape=dim(train_data)[2])

# outputs are input + dense layers
# two hidden relu layers sized to the feature count, linear 3-unit output
predictions<-inputs%>%
  layer_dense(units=dim(train_data)[2],activation="relu")%>%
  layer_dense(units=dim(train_data)[2],activation="relu")%>%
  layer_dense(units=dim(train_labels)[2])

# create model
model<-keras_model(inputs=inputs,outputs=predictions)

# compile
model%>%compile(optimizer="adam",loss="mse",metrics="mse")
  
# summary
model%>%summary()

# train model and store training progress learning curves
history<-model%>%fit(train_data,train_labels,epochs=50,validation_data=list(val_data,val_labels),verbose=1)

# model performance on test set
eval<-evaluate(model,test_data,test_labels,verbose=0)

# make predictions
test_predictions<-model%>%predict(test_data)

# back transform (inverse of log(x+1)), rounded to 2 decimals
test_predictions<-round(exp(test_predictions)-1,2)
true_labels<-round(exp(test_labels)-1,2)

# combine observed values and predictions
results<-data.frame(observed_small=as.numeric(true_labels[,1]),
                    observed_medium=as.numeric(true_labels[,2]),
                    observed_large=as.numeric(true_labels[,3]),
                    predicted_small=as.numeric(test_predictions[,1]),
                    predicted_medium=as.numeric(test_predictions[,2]),
                    predicted_large=as.numeric(test_predictions[,3]))

# store
learningcurves[[i]]<-history
testevals[[i]]<-eval
res[[i]]<-results

}
## Model: "model"
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## input_1 (InputLayer)             [(None, 185)]                 0           
## ___________________________________________________________________________
## dense (Dense)                    (None, 185)                   34410       
## ___________________________________________________________________________
## dense_1 (Dense)                  (None, 185)                   34410       
## ___________________________________________________________________________
## dense_2 (Dense)                  (None, 3)                     558         
## ===========================================================================
## Total params: 69,378
## Trainable params: 69,378
## Non-trainable params: 0
## ___________________________________________________________________________
## Model: "model_1"
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## input_2 (InputLayer)             [(None, 185)]                 0           
## ___________________________________________________________________________
## dense_3 (Dense)                  (None, 185)                   34410       
## ___________________________________________________________________________
## dense_4 (Dense)                  (None, 185)                   34410       
## ___________________________________________________________________________
## dense_5 (Dense)                  (None, 3)                     558         
## ===========================================================================
## Total params: 69,378
## Trainable params: 69,378
## Non-trainable params: 0
## ___________________________________________________________________________
## Model: "model_2"
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## input_3 (InputLayer)             [(None, 185)]                 0           
## ___________________________________________________________________________
## dense_6 (Dense)                  (None, 185)                   34410       
## ___________________________________________________________________________
## dense_7 (Dense)                  (None, 185)                   34410       
## ___________________________________________________________________________
## dense_8 (Dense)                  (None, 3)                     558         
## ===========================================================================
## Total params: 69,378
## Trainable params: 69,378
## Non-trainable params: 0
## ___________________________________________________________________________
## Model: "model_3"
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## input_4 (InputLayer)             [(None, 185)]                 0           
## ___________________________________________________________________________
## dense_9 (Dense)                  (None, 185)                   34410       
## ___________________________________________________________________________
## dense_10 (Dense)                 (None, 185)                   34410       
## ___________________________________________________________________________
## dense_11 (Dense)                 (None, 3)                     558         
## ===========================================================================
## Total params: 69,378
## Trainable params: 69,378
## Non-trainable params: 0
## ___________________________________________________________________________
## Model: "model_4"
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## input_5 (InputLayer)             [(None, 185)]                 0           
## ___________________________________________________________________________
## dense_12 (Dense)                 (None, 185)                   34410       
## ___________________________________________________________________________
## dense_13 (Dense)                 (None, 185)                   34410       
## ___________________________________________________________________________
## dense_14 (Dense)                 (None, 3)                     558         
## ===========================================================================
## Total params: 69,378
## Trainable params: 69,378
## Non-trainable params: 0
## ___________________________________________________________________________
# learning curves
# training vs validation loss per epoch for each of the five fold rotations
plot(learningcurves[[1]])

plot(learningcurves[[2]])

plot(learningcurves[[3]])

plot(learningcurves[[4]])

plot(learningcurves[[5]])

# mean squared error
# held-out test-set MSE (on the log(x+1) labels) for each of the five models
testevals[[1]]
## $loss
## [1] 0.6938727
## 
## $mse
## [1] 0.6938727
testevals[[2]]
## $loss
## [1] 0.6570169
## 
## $mse
## [1] 0.6570169
testevals[[3]]
## $loss
## [1] 0.6123519
## 
## $mse
## [1] 0.612352
testevals[[4]]
## $loss
## [1] 0.7369732
## 
## $mse
## [1] 0.7369732
testevals[[5]]
## $loss
## [1] 0.6990213
## 
## $mse
## [1] 0.6990213
# reshape prediction results for time series plots
# order of folds as testing data for loop is 5,1,2,3,4
# (model i is tested on fold mat[i,5], i.e. folds 5,1,2,3,4 in turn)
ind<-c(folds[[5]],folds[[1]],folds[[2]],folds[[3]],folds[[4]])
resdat<-data.frame(rbind(res[[1]],res[[2]],res[[3]],res[[4]],res[[5]]))
resdat$ind<-ind
resdat$ID<-alldata$ID[ind]
# attach tow metadata (columns 1:9 of alldata) to the predictions by tow ID
datdat<-left_join(alldata[,1:9],resdat,by="ID")

# GoM SPRING
datdat%>%
  group_by(year=EST_YEAR,area=AREA,season=SEASON)%>%
  summarise(psmall=sum(predicted_small),
            pmedium=sum(predicted_medium),
            plarge=sum(predicted_large),
            osmall=sum(observed_small),
            omedium=sum(observed_medium),
            olarge=sum(observed_large))%>%
  pivot_longer(cols=4:9,names_to=c("type","size"),values_to="abundance",names_sep=1)%>%
  filter(area=="GoM"&season=="SPRING")%>%
  ggplot(aes(x=year,y=abundance,group=type,color=type))+
  geom_line(size=1)+
  labs(title="GoM SPRING")+
  scale_color_discrete(name="type",labels=c("observed","predicted"))+
  facet_grid(factor(size,levels=c("small","medium","large"))~.,scales="free")

# GoM FALL
datdat%>%
  group_by(year=EST_YEAR,area=AREA,season=SEASON)%>%
  summarise(psmall=sum(predicted_small),
            pmedium=sum(predicted_medium),
            plarge=sum(predicted_large),
            osmall=sum(observed_small),
            omedium=sum(observed_medium),
            olarge=sum(observed_large))%>%
  pivot_longer(cols=4:9,names_to=c("type","size"),values_to="abundance",names_sep=1)%>%
  filter(area=="GoM"&season=="FALL")%>%
  ggplot(aes(x=year,y=abundance,group=type,color=type))+
  geom_line(size=1)+
  labs(title="GoM FALL")+
  scale_color_discrete(name="type",labels=c("observed","predicted"))+
  facet_grid(factor(size,levels=c("small","medium","large"))~.,scales="free")

# GB SPRING: annual observed vs predicted abundance, faceted by size class
datdat %>%
  filter(AREA == "GB", SEASON == "SPRING") %>%
  group_by(year = EST_YEAR, area = AREA, season = SEASON) %>%
  summarise(psmall = sum(predicted_small),
            pmedium = sum(predicted_medium),
            plarge = sum(predicted_large),
            osmall = sum(observed_small),
            omedium = sum(observed_medium),
            olarge = sum(observed_large)) %>%
  # split e.g. "psmall" after the first character: type ("p"/"o") + size
  pivot_longer(cols = 4:9, names_to = c("type", "size"),
               values_to = "abundance", names_sep = 1) %>%
  ggplot(aes(x = year, y = abundance, group = type, color = type)) +
  geom_line(size = 1) +
  labs(title = "GB SPRING") +
  scale_color_discrete(name = "type", labels = c("observed", "predicted")) +
  facet_grid(factor(size, levels = c("small", "medium", "large")) ~ .,
             scales = "free")

# GB FALL: annual observed vs predicted abundance, faceted by size class
datdat %>%
  filter(AREA == "GB", SEASON == "FALL") %>%
  group_by(year = EST_YEAR, area = AREA, season = SEASON) %>%
  summarise(psmall = sum(predicted_small),
            pmedium = sum(predicted_medium),
            plarge = sum(predicted_large),
            osmall = sum(observed_small),
            omedium = sum(observed_medium),
            olarge = sum(observed_large)) %>%
  # split e.g. "psmall" after the first character: type ("p"/"o") + size
  pivot_longer(cols = 4:9, names_to = c("type", "size"),
               values_to = "abundance", names_sep = 1) %>%
  ggplot(aes(x = year, y = abundance, group = type, color = type)) +
  geom_line(size = 1) +
  labs(title = "GB FALL") +
  scale_color_discrete(name = "type", labels = c("observed", "predicted")) +
  facet_grid(factor(size, levels = c("small", "medium", "large")) ~ .,
             scales = "free")

# Total observed vs predicted abundance by size class, pooled over all tows
datdat %>%
  # columns 10:15 are the observed_*/predicted_* counts; split names at "_"
  pivot_longer(cols = 10:15, names_to = c("type", "size"),
               values_to = "abundance", names_pattern = "(.*)_(.*)") %>%
  group_by(type, size) %>%
  summarise(abundance = sum(abundance)) %>%
  ggplot(aes(x = factor(type, levels = c("observed", "predicted")),
             y = abundance,
             fill = factor(size, levels = c("small", "medium", "large")))) +
  geom_col(position = "dodge") +
  labs(x = "type") +
  scale_fill_discrete(name = "size")

# Per-tow observed vs predicted abundance, one plot per size class
ggplot(datdat, aes(x = observed_small, y = predicted_small)) +
  geom_point(alpha = 0.1)

ggplot(datdat, aes(x = observed_medium, y = predicted_medium)) +
  geom_point(alpha = 0.1)

ggplot(datdat, aes(x = observed_large, y = predicted_large)) +
  geom_point(alpha = 0.1)

variable importance plot

library(vip)

# from https://bgreenwell.github.io/pdp/articles/pdp-example-tensorflow.html

# vip randomly permutes the values of each feature and records the drop in training performance

# prediction function wrapper, two arguments: object (the fitted model) and newdata
# The function needs to return a vector of predictions (one for each observation)
# Prediction wrapper for vip: takes the fitted model and a feature data
# frame, returns one numeric prediction per row of newdata.
# Column 2 of the network output is the medium size class.
pred_wrapper<-function(object,newdata){
  preds <- predict(object, x = as.matrix(newdata))
  as.vector(preds[, 2])
}

# Use the training data prior to resampling (the imbalanced set), so the
# importance scores reflect the real data distribution.
# Columns 1-3 are the one-hot size-class labels; features start at column 4.
orig_train_data<-imbal[,4:ncol(imbal)] # was ncol(bal): indexing imbal with
                                       # bal's width was correct only by
                                       # coincidence of equal column counts
orig_train_labels<-imbal[,c(1,2,3)]

# permutation-based VIP for the fitted network: vip permutes each feature
# and records the resulting drop in training performance
p1<-vip(object=model,                          # fitted model
        method="permute",                      # permutation-based VI scores
        num_features=10,                       # plots top 10 features
        pred_wrapper=pred_wrapper,             # user-defined prediction function
        train=as.data.frame(orig_train_data),  # training data
        target=orig_train_labels[,2],          # response values used for training (column 2 is medium size class)
        metric="mse",                          # evaluation metric
        progress="text")                       # request a text-based progress bar

# plot
p1

ICE curves (individual conditional expectation)

library(pdp)

# from https://christophm.github.io/interpretable-ml-book/ice.html

# Individual Conditional Expectation (ICE) plots display one line per instance that shows how the instance's prediction changes when a feature changes

# a partial dependence plot is an overall average of the ICE lines

# ICE curves for each feature of interest; every call shares the same
# prediction wrapper and the same (scaled) training features.
ice_curve <- function(feature) {
  partial(object = model,
          pred.var = feature,
          pred.fun = pred_wrapper,
          train = as.data.frame(orig_train_data))
}

p2 <- ice_curve("AVGDEPTH")
p3 <- ice_curve("anom")
p4 <- ice_curve("anom10")

# before unscale and back-transformation
grid.arrange(autoplot(p2, alpha = 0.1),
             autoplot(p3, alpha = 0.1),
             autoplot(p4, alpha = 0.1),
             ncol = 3)

# unscale
# AVGDEPTH was standardized before training; invert with x * sd + mean.
# Index 2 assumes AVGDEPTH is the 2nd entry of col_stddevs_trainval /
# col_means_trainval -- TODO confirm against the scaling code upstream.
p2$AVGDEPTH<-(p2$AVGDEPTH*col_stddevs_trainval[2])+col_means_trainval[2]

# back transform predictions
# Inverts an apparent log(y + 1) transform on the response -- confirm the
# model was indeed fit to log-transformed abundance.
# NOTE(review): the x-axes of p3 (anom) and p4 (anom10) are left on the
# scaled feature values; only AVGDEPTH is unscaled above -- confirm intended.
p2$yhat<-exp(p2$yhat)-1
p3$yhat<-exp(p3$yhat)-1
p4$yhat<-exp(p4$yhat)-1

# after unscale and back-transformation
grid.arrange(p2%>%autoplot(alpha=0.1),
             p3%>%autoplot(alpha=0.1),
             p4%>%autoplot(alpha=0.1),
             ncol=3)

partial dependence plot

# partial dependence plot shows marginal effect one or two features have on the predicted outcome 

# Redefine the wrapper for partial dependence: instead of one prediction
# per row, collapse the medium-size-class predictions (output column 2)
# to a single average across all observations.
pred_wrapper<-function(object,newdata){
  preds <- predict(object, x = as.matrix(newdata))
  mean(as.vector(preds[, 2]))
}

# partial dependence plot
# Joint marginal effect of AVGDEPTH and anom on the averaged medium-size
# prediction (uses the averaging pred_wrapper defined just above).
p5<-partial(object=model,
            pred.var=c("AVGDEPTH","anom"),
            chull=TRUE,                       # restrict predictions to region of joint values
            pred.fun=pred_wrapper,
            train=as.data.frame(orig_train_data))

# before unscale and back-transformation
p5%>%autoplot()

# unscale
# Index 2 assumes AVGDEPTH is the 2nd entry of the scaling vectors -- TODO
# confirm. NOTE(review): the anom axis is left on scaled feature values;
# only AVGDEPTH is unscaled here -- confirm intended.
p5$AVGDEPTH<-(p5$AVGDEPTH*col_stddevs_trainval[2])+col_means_trainval[2]

# back transform predictions
# Inverts an apparent log(y + 1) transform on the response -- confirm.
p5$yhat<-exp(p5$yhat)-1

# after unscale and back-transformation
p5%>%autoplot()